d9343f
@@ -139,6 +139,7 @@
public void onMatch(RelOptRuleCall call) {
 
     if (numCountDistinct > 1 && numCountDistinct == aggregate.getAggCallList().size()
         && aggregate.getGroupSet().isEmpty()) {
+      LOG.debug("Trigger countDistinct rewrite. numCountDistinct is " + numCountDistinct);
       // now positions contains all the distinct positions, i.e., $5, $4, $6
       // we need to first sort them as group by set
       // and then get their position later, i.e., $4->1, $5->2, $6->3
@@ -186,9 +187,12 @@
public void onMatch(RelOptRuleCall call) {
    * Converts an aggregate relational expression that contains only
    * count(distinct) to grouping sets with count. For example select
    * count(distinct department_id), count(distinct gender), count(distinct
-   * education_level) from employee; can be transformed to select count(case i
-   * when 1 then 1 else null end) as c0, count(case i when 2 then 1 else null
-   * end) as c1, count(case i when 4 then 1 else null end) as c2 from (select
+   * education_level) from employee; can be transformed to 
+   * select 
+   * count(case when i=1 and department_id is not null then 1 else null end) as c0, 
+   * count(case when i=2 and gender is not null then 1 else null end) as c1, 
+   * count(case when i=4 and education_level is not null then 1 else null end) as c2 
+   * from (select
    * grouping__id as i, department_id, gender, education_level from employee
    * group by department_id, gender, education_level grouping sets
    * (department_id, gender, education_level))subq;
@@ -230,13 +234,22 @@
public RexNode apply(RelDataTypeField input) {
           }
         });
     final List<RexNode> gbChildProjLst = Lists.newArrayList();
+    // For a single-argument count(distinct), NULL values must not be counted,
+    // e.g., count(case when i=1 and department_id is not null then 1 else null end) as c0.
+    // For multi-argument calls the count may include NULLs, i.e. (,) is counted as 1.
     for (List<Integer> list : cleanArgList) {
-      RexNode equal = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
-          originalInputRefs.get(originalInputRefs.size() - 1),
-          rexBuilder.makeExactLiteral(new BigDecimal(getGroupingIdValue(list, sourceOfForCountDistinct))));
-      RexNode condition = rexBuilder.makeCall(SqlStdOperatorTable.CASE, equal,
+      RexNode condition = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, originalInputRefs
+          .get(originalInputRefs.size() - 1), rexBuilder.makeExactLiteral(new BigDecimal(
+          getGroupingIdValue(list, sourceOfForCountDistinct))));
+      if (list.size() == 1) {
+        int pos = list.get(0);
+        RexNode notNull = rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL,
+            originalInputRefs.get(pos));
+        condition = rexBuilder.makeCall(SqlStdOperatorTable.AND, condition, notNull);
+      }
+      RexNode when = rexBuilder.makeCall(SqlStdOperatorTable.CASE, condition,
           rexBuilder.makeExactLiteral(BigDecimal.ONE), rexBuilder.constantNull());
-      gbChildProjLst.add(condition);
+      gbChildProjLst.add(when);
     }
 
     // create the project before GB
